#!/usr/bin/env python3

"""
Given a linked list and a value x, partition it such that all nodes with values less than x
come before nodes with values greater than or equal to x. The relative order of the nodes
within each partition is preserved.
"""

class ListNode:
    """Node of a singly linked list.

    Attributes:
        val: The value stored in this node.
        next: The following node, or None when this node is the tail.
    """

    def __init__(self, x):
        self.val = x
        self.next = None

    def link(self, node):
        """Attach *node* directly after this node and return *node*.

        Returning the attached node lets calls be chained, e.g.
        ``a.link(b).link(c)``, to build a list from front to back.
        """
        self.next = node
        return node

    def show(self, end='\n'):
        """Print the values from this node through the tail on one line.

        Values are separated by single spaces and the line is terminated
        with *end*.
        """
        # Walk the list iteratively: one recursive call per node would raise
        # RecursionError on lists longer than the interpreter's recursion limit.
        values = []
        node = self
        while node is not None:
            values.append(str(node.val))
            node = node.next
        # Emit everything in one call so *end* is written exactly once and no
        # stray separators trail the last value.
        print(' '.join(values), end=end)

class Solution:
    """Stable partition of a singly linked list around a pivot value."""

    def partition(self, head, x):
        """Reorder the list so nodes with ``val < x`` precede all others.

        The existing nodes are relinked in place; no nodes are created.
        Within each of the two groups the original relative order is kept.

        Args:
            head: First node of the list, or None for an empty list.
            x: Pivot value that node values are compared against.

        Returns:
            The head of the partitioned list, or None if *head* is None.
        """
        low_first = low_last = None
        high_first = high_last = None

        node = head
        while node is not None:
            if node.val >= x:
                # Append to the "greater or equal" chain.
                if high_last is None:
                    high_first = node
                else:
                    high_last.next = node
                high_last = node
            else:
                # Append to the "less than" chain.
                if low_last is None:
                    low_first = node
                else:
                    low_last.next = node
                low_last = node
            node = node.next

        # The last high node may still point at a node that now lives in the
        # low chain; cut that link so the result is properly terminated.
        if high_last is not None:
            high_last.next = None

        if low_last is None:
            return high_first

        # Splice the two chains together (high_first is None when no node
        # was >= x, which simply terminates the low chain).
        low_last.next = high_first
        return low_first


if __name__ == '__main__':
    # Demo: build 1 -> 4 -> 3 -> 2 -> 5 -> 2, print it, then print the list
    # again after partitioning around the pivot value 3.
    head = ListNode(1)
    tail = head
    for value in (4, 3, 2, 5, 2):
        tail = tail.link(ListNode(value))
    head.show()

    partitioned = Solution().partition(head, 3)
    partitioned.show()
